import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.data as gd
import torch_geometric.nn as gnn
from torch_geometric.utils import add_self_loops

from gflownet.envs.graph_building_env import GraphActionCategorical


def mlp(n_in, n_hid, n_out, n_layer, act=nn.LeakyReLU, keep_last_act=False):
    """Creates a fully-connected network as an `nn.Sequential`.

    The network has `n_layer` hidden layers of width `n_hid` followed by an output layer of
    width `n_out`. An `act()` module is inserted after every linear layer except, by default,
    the last one. If `n_layer` is 0 (and `keep_last_act` is false) then this corresponds to
    `nn.Linear(n_in, n_out)`.

    Parameters
    ----------
    n_in: int
        Input width.
    n_hid: int
        Width of each hidden layer.
    n_out: int
        Output width.
    n_layer: int
        Number of hidden layers.
    act: callable
        Zero-argument factory for the activation module. Default `nn.LeakyReLU`.
    keep_last_act: bool
        If true, an activation is also kept after the output layer. Interpreted by truthiness.

    Returns
    -------
    nn.Sequential
        The assembled network.
    """
    widths = [n_in] + [n_hid] * n_layer + [n_out]
    layers = []
    for i in range(n_layer + 1):
        layers.append(nn.Linear(widths[i], widths[i + 1]))
        layers.append(act())
    # Every linear layer was followed by an activation; drop the trailing one unless requested.
    if layers and not keep_last_act:
        layers.pop()
    return nn.Sequential(*layers)


class GraphTransformer(nn.Module):
    """An agnostic GraphTransformer class, and the main model used by other model classes

    This graph model takes in node features, edge features, and graph features (referred to as
    conditional information, since they condition the output). The graph features are projected to
    virtual nodes (one per graph), each of which is connected to every node of its graph.

    The per node outputs are the concatenation of the final (post graph-convolution) node embeddings
    and of the conditional information embedding of the graph each node corresponds to.

    The per graph outputs are the concatenation of a global mean pooling operation, of the final
    virtual node embeddings, and of the conditional information embedding.
    """
    def __init__(self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2):
        """
        Parameters
        ----------
        x_dim: int
            The number of node features
        e_dim: int
            The number of edge features
        g_dim: int
            The number of graph-level features
        num_emb: int
            The number of hidden dimensions, i.e. embedding size. Default 64.
        num_layers: int
            The number of Transformer layers.
        num_heads: int
            The number of Transformer heads per layer.
        """
        super().__init__()
        self.num_layers = num_layers

        # Input projections of node, edge and conditioning features to the embedding size
        self.x2h = mlp(x_dim, num_emb, num_emb, 2)
        self.e2h = mlp(e_dim, num_emb, num_emb, 2)
        self.c2h = mlp(g_dim, num_emb, num_emb, 2)
        # Flat list holding 6 modules per layer; `forward` slices it in groups of 6, so the
        # order within a group must match the unpacking done there.
        layers = []
        for _ in range(self.num_layers):
            layers.extend([
                gnn.GENConv(num_emb, num_emb, num_layers=1, aggr='add', norm=None),
                gnn.TransformerConv(num_emb * 2, num_emb, edge_dim=num_emb, heads=num_heads),
                nn.Linear(num_heads * num_emb, num_emb),
                gnn.LayerNorm(num_emb, affine=False),
                mlp(num_emb, num_emb * 4, num_emb, 1),
                gnn.LayerNorm(num_emb, affine=False),
            ])
        self.graph2emb = nn.ModuleList(layers)

    def forward(self, g: gd.Batch, cond: torch.Tensor):
        """Forward pass

        Parameters
        ----------
        g: gd.Batch
            A standard torch_geometric Batch object. Expects `edge_attr` to be set.
        cond: torch.Tensor
            The per-graph conditioning information. Shape: (g.num_graphs, g_dim).

        Returns
        -------
        node_embeddings: torch.Tensor
            Per node embeddings: the final node embeddings concatenated with the conditioning
            embedding of the node's graph. Shape: (g.num_nodes, num_emb * 2).
        graph_embeddings: torch.Tensor
            Per graph embeddings: mean-pooled final node embeddings, final virtual node
            embeddings and conditioning embedding. Shape: (g.num_graphs, num_emb * 3).
        """
        o = self.x2h(g.x)
        e = self.e2h(g.edge_attr)
        c = self.c2h(cond)
        num_total_nodes = g.x.shape[0]
        num_graphs = c.shape[0]
        # Augment the edges with a new edge to the conditioning
        # information node. This new node is connected to every node
        # within its graph.
        u, v = torch.arange(num_total_nodes, device=o.device), g.batch + num_total_nodes
        aug_edge_index = torch.cat([g.edge_index, torch.stack([u, v]), torch.stack([v, u])], 1)
        e_p = torch.zeros((num_total_nodes * 2, e.shape[1]), device=g.x.device)
        e_p[:, 0] = 1  # Manually create a bias term
        aug_e = torch.cat([e, e_p], 0)
        # Pass the node count explicitly: inferring it from the largest index in the edge list
        # would miss virtual nodes of trailing graphs that contain no nodes.
        aug_edge_index, aug_e = add_self_loops(aug_edge_index, aug_e, 'mean',
                                               num_nodes=num_total_nodes + num_graphs)
        aug_batch = torch.cat([g.batch, torch.arange(num_graphs, device=o.device)], 0)

        # Append the conditioning information node embedding to o
        o = torch.cat([o, c], 0)
        for i in range(self.num_layers):
            # Run the graph transformer forward
            gen, trans, linear, norm1, ff, norm2 = self.graph2emb[i * 6:(i + 1) * 6]
            agg = gen(o, aug_edge_index, aug_e)
            # Residual attention update followed by a residual feed-forward update
            o = norm1(o + linear(trans(torch.cat([o, agg], 1), aug_edge_index, aug_e)), aug_batch)
            o = norm2(o + ff(o), aug_batch)

        # The first `num_total_nodes` rows are real nodes, the remaining ones are virtual nodes.
        # Positive offsets are used so that the split stays correct when there are no graphs.
        node_emb, virtual_emb = o[:num_total_nodes], o[num_total_nodes:]
        glob = torch.cat([gnn.global_mean_pool(node_emb, g.batch, size=num_graphs), virtual_emb, c], 1)
        # NOTE(review): nodes receive the conditioning embedding `c`, not the final virtual node
        # embedding -- confirm this is the intended per-node output.
        o_final = torch.cat([node_emb, c[g.batch]], 1)
        return o_final, glob


class GraphTransformerGFN(nn.Module):
    """GraphTransformer-based GFlowNet model producing a GraphActionCategorical. Meant for
    atom-wise generation.

    One group of logits is produced for each of the following actions
    - Stop
    - AddNode
    - SetNodeAttr
    - AddEdge
    - SetEdgeAttr

    """
    def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2):
        """See `GraphTransformer` for argument values"""
        super().__init__()
        self.transf = GraphTransformer(
            x_dim=env_ctx.num_node_dim,
            e_dim=env_ctx.num_edge_dim,
            g_dim=env_ctx.num_cond_dim,
            num_emb=num_emb,
            num_layers=num_layers,
            num_heads=num_heads,
        )
        # Widths of the per-node and per-graph embeddings returned by `self.transf`
        node_width = num_emb * 2
        graph_width = num_emb * 3

        def head(n_in, n_out):
            # Output heads have no hidden layer, i.e. they are single linear layers
            return mlp(n_in, num_emb, n_out, 0)

        self.emb2add_edge = head(node_width, 1)
        self.emb2add_node = head(node_width, env_ctx.num_new_node_values)
        self.emb2set_node_attr = head(node_width, env_ctx.num_node_attr_logits)
        self.emb2set_edge_attr = head(node_width, env_ctx.num_edge_attr_logits)
        self.emb2stop = head(graph_width, 1)
        self.emb2reward = head(graph_width, 1)
        # Maps conditioning information to a scalar; not used by `forward`
        self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2)
        self.action_type_order = env_ctx.action_type_order

    def forward(self, g: gd.Batch, cond: torch.Tensor):
        """Returns the action distribution and the output of the reward head for a batch.

        See `GraphTransformer.forward` for argument values. `g` must also provide
        `non_edge_index`.
        """
        node_emb, graph_emb = self.transf(g, cond)
        non_edge_src, non_edge_dst = g.non_edge_index
        # Edges are duplicated to make graphs undirected; `::2` keeps only the even ones
        edge_src, edge_dst = g.edge_index[:, ::2]

        stop_logits = self.emb2stop(graph_emb)
        add_node_logits = self.emb2add_node(node_emb)
        set_node_attr_logits = self.emb2set_node_attr(node_emb)
        # Pairs of nodes are represented by the sum of their embeddings
        add_edge_logits = self.emb2add_edge(node_emb[non_edge_src] + node_emb[non_edge_dst])
        set_edge_attr_logits = self.emb2set_edge_attr(node_emb[edge_src] + node_emb[edge_dst])

        policy = GraphActionCategorical(
            g,
            logits=[stop_logits, add_node_logits, set_node_attr_logits, add_edge_logits, set_edge_attr_logits],
            keys=[None, 'x', 'x', 'non_edge_index', 'edge_index'],
            types=self.action_type_order,
        )
        reward_pred = self.emb2reward(graph_emb)
        return policy, reward_pred


class GraphTransformerFragGFN(nn.Module):
    """GraphTransformer-based GFlowNet model producing a GraphActionCategorical. Meant for
    fragment-wise generation.

    One group of logits is produced for each of the following actions
    - Stop
    - AddNode
    - SetEdgeAttr
    """
    def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2):
        """See `GraphTransformer` for argument values"""
        super().__init__()
        self.transf = GraphTransformer(
            x_dim=env_ctx.num_node_dim,
            e_dim=env_ctx.num_edge_dim,
            g_dim=env_ctx.num_cond_dim,
            num_emb=num_emb,
            num_layers=num_layers,
            num_heads=num_heads,
        )
        # Widths of the per-node and per-graph embeddings returned by `self.transf`
        node_width = num_emb * 2
        graph_width = num_emb * 3

        def head(n_in, n_out):
            # Output heads have no hidden layer, i.e. they are single linear layers
            return mlp(n_in, num_emb, n_out, 0)

        self.emb2add_node = head(node_width, env_ctx.num_new_node_values)
        # Edge attr logits are "sided": the same head is applied once per side of the edge, so
        # it only outputs half of the logits
        self.emb2set_edge_attr = head(num_emb + node_width, env_ctx.num_edge_attr_logits // 2)
        self.emb2stop = head(graph_width, 1)
        self.emb2reward = head(graph_width, 1)
        self.edge2emb = head(node_width, num_emb)
        # Maps conditioning information to a scalar; not used by `forward`
        self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2)
        self.action_type_order = env_ctx.action_type_order
        # When False, `forward` leaves the logits unmasked
        self.do_mask = True

    def forward(self, g: gd.Batch, cond: torch.Tensor):
        """Returns the action distribution and the output of the reward head for a batch.

        See `GraphTransformer.forward` for argument values. `g` must also provide
        `add_node_mask` and `set_edge_attr_mask`.
        """
        node_emb, graph_emb = self.transf(g, cond)
        # Edges are duplicated to make graphs undirected; `::2` keeps only the even ones
        edge_src, edge_dst = g.edge_index[:, ::2]
        edge_emb = self.edge2emb(node_emb[edge_src] + node_emb[edge_dst])
        # One set of anchor logits per side of each edge: source side first, then destination
        anchor_logits = [
            self.emb2set_edge_attr(torch.cat([edge_emb, node_emb[side]], 1)) for side in (edge_src, edge_dst)
        ]

        def _mask(x, m):
            # Where the binary mask m is 0, replace the logit with -1000 (a tiny log-value)
            if not self.do_mask:
                return x
            return x * m + -1000 * (1 - m)

        stop_logits = self.emb2stop(graph_emb)
        add_node_logits = _mask(self.emb2add_node(node_emb), g.add_node_mask)
        set_edge_attr_logits = _mask(torch.cat(anchor_logits, 1), g.set_edge_attr_mask)

        policy = GraphActionCategorical(
            g,
            logits=[stop_logits, add_node_logits, set_edge_attr_logits],
            keys=[None, 'x', 'edge_index'],
            types=self.action_type_order,
        )
        reward_pred = self.emb2reward(graph_emb)
        return policy, reward_pred


class GraphTransformerFragEnvelopeQL(nn.Module):
    """GraphTransformer-based model for an EnvelopeQLearning agent

    Q values (`num_objectives` per action) are produced for the following actions
    - Stop
    - AddNode
    - SetEdgeAttr
    """
    def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objectives=2):
        """See `GraphTransformer` for argument values; `num_objectives` is the number of values
        output per action."""
        super().__init__()
        self.transf = GraphTransformer(
            x_dim=env_ctx.num_node_dim,
            e_dim=env_ctx.num_edge_dim,
            g_dim=env_ctx.num_cond_dim,
            num_emb=num_emb,
            num_layers=num_layers,
            num_heads=num_heads,
        )
        # Widths of the per-node and per-graph embeddings returned by `self.transf`
        node_width = num_emb * 2
        graph_width = num_emb * 3

        def head(n_in, n_out):
            # Output heads have no hidden layer, i.e. they are single linear layers
            return mlp(n_in, num_emb, n_out, 0)

        self.emb2add_node = head(node_width, env_ctx.num_new_node_values * num_objectives)
        # Edge attr logits are "sided": the same head is applied once per side of the edge, so
        # it only outputs half of the logits
        self.emb2set_edge_attr = head(num_emb + node_width, env_ctx.num_edge_attr_logits // 2 * num_objectives)
        self.emb2stop = head(graph_width, env_ctx.num_stop_logits * num_objectives)
        self.emb2reward = head(graph_width, 1)
        self.edge2emb = head(node_width, num_emb)
        # Maps conditioning information to a scalar; not used by `forward`
        self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2)
        self.action_type_order = env_ctx.action_type_order
        # Value written in place of masked-out entries
        self.mask_value = -10
        self.num_objectives = num_objectives

    def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
        """Returns a GraphActionCategorical and the output of the reward head for a batch.

        See `GraphTransformer.forward` for `g` and `cond`. If `output_Qs` is true the
        categorical holds the per-objective Q values; otherwise its logits are replaced by the
        sharpened greedy policy computed from them.
        """
        node_emb, graph_emb = self.transf(g, cond)
        # Edges are duplicated to make graphs undirected; `::2` keeps only the even ones
        edge_src, edge_dst = g.edge_index[:, ::2]
        edge_emb = self.edge2emb(node_emb[edge_src] + node_emb[edge_dst])
        # One set of anchor values per side of each edge: source side first, then destination
        anchor_logits = [
            self.emb2set_edge_attr(torch.cat([edge_emb, node_emb[side]], 1)) for side in (edge_src, edge_dst)
        ]

        def _mask(x, m):
            # Where the binary mask m is 0, replace the entry of x with `self.mask_value`
            return x * m + self.mask_value * (1 - m)

        def _mask_obj(x, m):
            # Same as `_mask`, but each mask entry covers `num_objectives` consecutive columns
            per_objective = x.reshape(x.shape[0], x.shape[1] // self.num_objectives, self.num_objectives)
            expanded = m[:, :, None]
            return (per_objective * expanded + self.mask_value * (1 - expanded)).reshape(x.shape)

        stop_q = F.relu(self.emb2stop(graph_emb))
        add_node_q = _mask(F.relu(self.emb2add_node(node_emb)), g.add_node_mask)
        set_edge_attr_q = _mask_obj(F.relu(torch.cat(anchor_logits, 1)), g.set_edge_attr_mask)

        q_cat = GraphActionCategorical(
            g,
            logits=[stop_q, add_node_q, set_edge_attr_q],
            keys=[None, 'x', 'edge_index'],
            types=self.action_type_order,
        )
        reward_pred = self.emb2reward(graph_emb)
        if output_Qs:
            return q_cat, reward_pred

        q_cat.masks = [1, g.add_node_mask.cpu(), g.set_edge_attr_mask.cpu()]
        # Compute the greedy policy
        # See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations
        # TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes
        w = cond[:, -self.num_objectives:]
        n_obj = w.shape[1]
        greedy_logits = []
        for q, batch_idx in zip(q_cat.logits, q_cat.batch):
            # Weight each objective's Q value by the preference of the corresponding graph
            q_per_obj = q.reshape((q.shape[0], q.shape[1] // n_obj, n_obj))
            w_dot_q = (q_per_obj * w[batch_idx][:, None, :]).sum(2)
            # Set the softmax distribution to a very low temperature to make sure only the max
            # gets sampled (and we get random argmax tie breaking for free!)
            greedy_logits.append(w_dot_q * 100)
        q_cat.logits = greedy_logits
        return q_cat, reward_pred


class GraphTransformerEnvelopeQL(nn.Module):
    """GraphTransformer-based model with `num_objectives` outputs per action, whose `forward`
    returns either these outputs or the greedy policy derived from them.

    Output heads exist for stop, add-node, set-node-attr, add-edge and set-edge-attr.
    """
    def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objectives=2):
        """See `GraphTransformer` for argument values"""
        super().__init__()
        self.transf = GraphTransformer(
            x_dim=env_ctx.num_node_dim,
            e_dim=env_ctx.num_edge_dim,
            g_dim=env_ctx.num_cond_dim,
            num_emb=num_emb,
            num_layers=num_layers,
            num_heads=num_heads,
        )
        # Widths of the per-node and per-graph embeddings returned by `self.transf`
        node_width = num_emb * 2
        graph_width = num_emb * 3

        def head(n_in, n_out):
            # Output heads have no hidden layer, i.e. they are single linear layers
            return mlp(n_in, num_emb, n_out, 0)

        self.emb2add_edge = head(node_width, num_objectives)
        self.emb2add_node = head(node_width, env_ctx.num_new_node_values * num_objectives)
        self.emb2set_node_attr = head(node_width, env_ctx.num_node_attr_logits * num_objectives)
        self.emb2set_edge_attr = head(node_width, env_ctx.num_edge_attr_logits * num_objectives)
        self.emb2stop = head(graph_width, num_objectives)
        self.emb2reward = head(graph_width, 1)
        # Maps conditioning information to a scalar; not used by `forward`
        self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2)
        self.action_type_order = env_ctx.action_type_order
        self.num_objectives = num_objectives

    def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
        """Returns a GraphActionCategorical and the output of the reward head for a batch.

        See `GraphTransformer.forward` for `g` and `cond`; `g` must also provide
        `non_edge_index`. If `output_Qs` is true the categorical holds the raw per-objective
        head outputs; otherwise its logits are replaced by the sharpened greedy policy computed
        from them.
        """
        node_emb, graph_emb = self.transf(g, cond)
        non_edge_src, non_edge_dst = g.non_edge_index
        # Edges are duplicated to make graphs undirected; `::2` keeps only the even ones
        edge_src, edge_dst = g.edge_index[:, ::2]

        stop_q = self.emb2stop(graph_emb)
        add_node_q = self.emb2add_node(node_emb)
        set_node_attr_q = self.emb2set_node_attr(node_emb)
        # Pairs of nodes are represented by the sum of their embeddings
        add_edge_q = self.emb2add_edge(node_emb[non_edge_src] + node_emb[non_edge_dst])
        set_edge_attr_q = self.emb2set_edge_attr(node_emb[edge_src] + node_emb[edge_dst])

        q_cat = GraphActionCategorical(
            g,
            logits=[stop_q, add_node_q, set_node_attr_q, add_edge_q, set_edge_attr_q],
            keys=[None, 'x', 'x', 'non_edge_index', 'edge_index'],
            types=self.action_type_order,
        )
        reward_pred = self.emb2reward(graph_emb)
        if output_Qs:
            return q_cat, reward_pred

        # Compute the greedy policy
        # See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations
        # TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes
        w = cond[:, -self.num_objectives:]
        n_obj = w.shape[1]
        greedy_logits = []
        for q, batch_idx in zip(q_cat.logits, q_cat.batch):
            # Weight each objective's output by the preference of the corresponding graph
            q_per_obj = q.reshape((q.shape[0], q.shape[1] // n_obj, n_obj))
            w_dot_q = (q_per_obj * w[batch_idx][:, None, :]).sum(2)
            # Set the softmax distribution to a very low temperature to make sure only the max
            # gets sampled (and we get random argmax tie breaking for free!)
            greedy_logits.append(w_dot_q * 100)
        q_cat.logits = greedy_logits
        return q_cat, reward_pred
